{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "pytorch_geometric_edge_conv_test.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "F4y_11GUfywX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "import torch_geometric\n",
        "import numpy as np"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "P3PVUXPhf45S",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from torch.nn import Sequential as Seq, Linear as Lin, ReLU\n",
        "from torch_geometric.nn import MessagePassing\n",
        "\n",
        "class EdgeConv(MessagePassing):\n",
        "    \"\"\"Edge convolution layer instrumented with prints to trace message passing.\n",
        "\n",
        "    Every edge produces the message ``mlp(cat([x_i, x_j - x_i]))`` and the messages\n",
        "    are combined with max aggregation. The layer is created with\n",
        "    ``flow='target_to_source'``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, F_in, F_out):\n",
        "        \"\"\"Create the layer.\n",
        "\n",
        "        Args:\n",
        "            F_in: number of input features per node.\n",
        "            F_out: number of output features per node.\n",
        "        \"\"\"\n",
        "        super().__init__(aggr='max', flow='target_to_source')  # \"Max\" aggregation.\n",
        "        # The MLP consumes the concatenation [x_i, x_j - x_i], hence 2 * F_in inputs.\n",
        "        self.mlp = Seq(Lin(2 * F_in, F_out), ReLU(), Lin(F_out, F_out))\n",
        "\n",
        "    def forward(self, x, edge_index):\n",
        "        \"\"\"Run one round of message passing.\n",
        "\n",
        "        Args:\n",
        "            x: node features, shape [N, F_in].\n",
        "            edge_index: edge list, shape [2, E].\n",
        "\n",
        "        Returns:\n",
        "            Aggregated node features, shape [N, F_out].\n",
        "        \"\"\"\n",
        "        print('in forward x is :{}'.format(x))\n",
        "        # `h` is a scaled copy of `x`; it is handed to propagate() only so that\n",
        "        # message() can print how `h`, `h_i` and `h_j` arrive. It does not affect the result.\n",
        "        h = x * 1.5\n",
        "        return self.propagate(edge_index, x=x, h=h)  # shape [N, F_out]\n",
        "\n",
        "    def message(self, x_i, x_j, h, h_i, h_j):\n",
        "        \"\"\"Compute one message per edge and print every intermediate tensor.\n",
        "\n",
        "        Args:\n",
        "            x_i: per-edge node features, shape [E, F_in].\n",
        "            x_j: per-edge features of the opposite end point, shape [E, F_in].\n",
        "            h: the `h` keyword argument given to propagate() (printed only).\n",
        "            h_i: per-edge counterpart of `x_i` taken from `h` (printed only).\n",
        "            h_j: per-edge counterpart of `x_j` taken from `h` (printed only).\n",
        "\n",
        "        Returns:\n",
        "            Edge messages, shape [E, F_out].\n",
        "        \"\"\"\n",
        "        print('x_i:{}'.format(x_i))\n",
        "        print('x_j:{}'.format(x_j))\n",
        "        print('h  :{}'.format(h))\n",
        "        print('h i :{}'.format(h_i))\n",
        "        print('h j :{}'.format(h_j))\n",
        "        edge_features = torch.cat([x_i, x_j - x_i], dim=1)  # shape [E, 2 * F_in]\n",
        "        print('edge_features :{}'.format(edge_features))\n",
        "        # Evaluate the MLP once and reuse the result for both the debug print and the\n",
        "        # return value; calling it twice doubles the forward cost and the autograd graph.\n",
        "        messages = self.mlp(edge_features)  # shape [E, F_out]\n",
        "        print('mlp edge_fearutes :{}'.format(messages))\n",
        "        return messages"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XYNuUBPGgUYD",
        "colab_type": "code",
        "outputId": "cb11ef92-e7de-4e6a-b2a5-d4405da4887b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 971
        }
      },
      "source": [
        "# Seed every RNG first so the node features and the layer's initial weights are reproducible.\n",
        "SEED = 110\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "torch.cuda.manual_seed(SEED)\n",
        "\n",
        "# Four-node path graph 0-1-2-3 with every edge listed in both directions.\n",
        "# (A fully connected 4-node graph with 16 random features per node is another useful test input.)\n",
        "edge_index_ = torch.tensor([[0, 1, 1, 2, 2, 3],\n",
        "                            [1, 0, 2, 1, 3, 2]])\n",
        "x = torch.randn((4, 2))\n",
        "num_nodes, num_features = x.size()\n",
        "\n",
        "print('node feature:{} shape {}'.format(x, x.size()))\n",
        "print('node num is :{}'.format(num_nodes))\n",
        "\n",
        "# The layer is built only after `x` has been sampled, so the random stream is consumed\n",
        "# in the same order on every run.\n",
        "edge_conv = EdgeConv(num_features, num_features * 2)\n",
        "hidden_feature = edge_conv(x, edge_index_)\n",
        "print('node feature:{} shape {}'.format(hidden_feature, hidden_feature.size()))"
      ],
      "execution_count": 74,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "node feature:tensor([[-0.4542, -1.1170],\n",
            "        [ 0.7534,  2.6707],\n",
            "        [-1.2473, -0.8668],\n",
            "        [ 0.1647, -0.1978]]) shape torch.Size([4, 2])\n",
            "node num is :4\n",
            "in forward x is :tensor([[-0.4542, -1.1170],\n",
            "        [ 0.7534,  2.6707],\n",
            "        [-1.2473, -0.8668],\n",
            "        [ 0.1647, -0.1978]])\n",
            "x_i:tensor([[-0.4542, -1.1170],\n",
            "        [ 0.7534,  2.6707],\n",
            "        [ 0.7534,  2.6707],\n",
            "        [-1.2473, -0.8668],\n",
            "        [-1.2473, -0.8668],\n",
            "        [ 0.1647, -0.1978]])\n",
            "x_j:tensor([[ 0.7534,  2.6707],\n",
            "        [-0.4542, -1.1170],\n",
            "        [-1.2473, -0.8668],\n",
            "        [ 0.7534,  2.6707],\n",
            "        [ 0.1647, -0.1978],\n",
            "        [-1.2473, -0.8668]])\n",
            "h  :tensor([[-0.6812, -1.6755],\n",
            "        [ 1.1302,  4.0060],\n",
            "        [-1.8710, -1.3002],\n",
            "        [ 0.2470, -0.2968]])\n",
            "h i :tensor([[-0.6812, -1.6755],\n",
            "        [ 1.1302,  4.0060],\n",
            "        [ 1.1302,  4.0060],\n",
            "        [-1.8710, -1.3002],\n",
            "        [-1.8710, -1.3002],\n",
            "        [ 0.2470, -0.2968]])\n",
            "h j :tensor([[ 1.1302,  4.0060],\n",
            "        [-0.6812, -1.6755],\n",
            "        [-1.8710, -1.3002],\n",
            "        [ 1.1302,  4.0060],\n",
            "        [ 0.2470, -0.2968],\n",
            "        [-1.8710, -1.3002]])\n",
            "edge_features :tensor([[-0.4542, -1.1170,  1.2076,  3.7877],\n",
            "        [ 0.7534,  2.6707, -1.2076, -3.7877],\n",
            "        [ 0.7534,  2.6707, -2.0007, -3.5374],\n",
            "        [-1.2473, -0.8668,  2.0007,  3.5374],\n",
            "        [-1.2473, -0.8668,  1.4120,  0.6689],\n",
            "        [ 0.1647, -0.1978, -1.4120, -0.6689]])\n",
            "mlp edge_fearutes :tensor([[ 0.4431,  0.3397,  0.1702, -0.4085],\n",
            "        [ 0.2105, -1.0505,  0.1407, -1.5917],\n",
            "        [ 0.3746, -1.1566,  0.2734, -1.5619],\n",
            "        [ 0.4180,  0.3458,  0.1456, -0.4276],\n",
            "        [ 0.3594,  0.2429,  0.0281, -0.5120],\n",
            "        [ 0.4252, -0.1909,  0.0702, -0.7635]], grad_fn=<AddmmBackward>)\n",
            "node feature:tensor([[ 0.4431,  0.3397,  0.1702, -0.4085],\n",
            "        [ 0.3746, -1.0505,  0.2734, -1.5619],\n",
            "        [ 0.4180,  0.3458,  0.1456, -0.4276],\n",
            "        [ 0.4252, -0.1909,  0.0702, -0.7635]], grad_fn=<IndexPutBackward>) shape torch.Size([4, 4])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Urb02NjZQX4G",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 881
        },
        "outputId": "baaa2cfd-0f49-4d58-8bbf-dd7e0ac0e96c"
      },
      "source": [
        "# Print each learnable tensor of the EdgeConv layer as a (name, parameter) tuple.\n",
        "for name, param in edge_conv.named_parameters():\n",
        "    print((name, param))"
      ],
      "execution_count": 53,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "('mlp.0.weight', Parameter containing:\n",
            "tensor([[ 0.2339, -0.1264, -0.0869,  0.0030, -0.0564, -0.2219, -0.1512, -0.2413,\n",
            "          0.2284, -0.2534],\n",
            "        [ 0.2174,  0.0612,  0.1780, -0.1818, -0.0561, -0.2301,  0.2644,  0.2187,\n",
            "          0.0921, -0.2705],\n",
            "        [ 0.1420, -0.3138,  0.2522, -0.0327,  0.2440, -0.2843, -0.1267,  0.1805,\n",
            "          0.1376,  0.1444],\n",
            "        [ 0.2422,  0.2226,  0.2828,  0.0410,  0.0326,  0.1039, -0.1294,  0.1071,\n",
            "          0.0511,  0.1972],\n",
            "        [-0.2804,  0.0284,  0.1446,  0.1746, -0.2075,  0.0478,  0.1570,  0.2182,\n",
            "          0.2310,  0.2890],\n",
            "        [-0.2611, -0.1077, -0.0249,  0.1980, -0.1131,  0.1167,  0.1708, -0.2562,\n",
            "          0.1654, -0.2475],\n",
            "        [-0.2213,  0.1678, -0.0679,  0.3111,  0.3126,  0.2350,  0.2912,  0.1298,\n",
            "          0.2645, -0.1921],\n",
            "        [ 0.0991, -0.0082,  0.2347, -0.2957, -0.1309,  0.1926,  0.0125, -0.0239,\n",
            "          0.1034,  0.0145],\n",
            "        [ 0.1885, -0.1790,  0.3056,  0.0712,  0.2580, -0.1210, -0.0615, -0.0611,\n",
            "          0.0570, -0.0844],\n",
            "        [ 0.1230, -0.1383,  0.1476,  0.2305, -0.2721, -0.0053, -0.0613,  0.1574,\n",
            "          0.0017,  0.0980]], requires_grad=True))\n",
            "('mlp.0.bias', Parameter containing:\n",
            "tensor([-0.2110, -0.1835, -0.1959, -0.0823,  0.2996,  0.2757, -0.0217,  0.1970,\n",
            "         0.2497,  0.2583], requires_grad=True))\n",
            "('mlp.2.weight', Parameter containing:\n",
            "tensor([[-0.2413,  0.1094,  0.2596,  0.0628, -0.3082, -0.2323, -0.3098, -0.0365,\n",
            "         -0.2900, -0.3140],\n",
            "        [-0.2843,  0.2262, -0.0223, -0.1978,  0.2194, -0.3073, -0.2528, -0.2523,\n",
            "         -0.2183, -0.0563],\n",
            "        [ 0.0920, -0.0771,  0.1916, -0.2627,  0.2995, -0.0692,  0.2682, -0.1337,\n",
            "          0.2309,  0.1829],\n",
            "        [ 0.2918, -0.0528, -0.2685,  0.0459,  0.0387, -0.1977, -0.0086, -0.1652,\n",
            "         -0.2710, -0.0028],\n",
            "        [-0.1893,  0.0744, -0.1683,  0.3037, -0.1887, -0.1031, -0.1407, -0.0222,\n",
            "          0.1331,  0.2675],\n",
            "        [ 0.0574,  0.0075,  0.2512, -0.0783,  0.0900,  0.0342,  0.1862, -0.1559,\n",
            "         -0.0247,  0.0647],\n",
            "        [-0.0487, -0.2114,  0.1277, -0.1497,  0.3013,  0.1365,  0.2188, -0.0927,\n",
            "         -0.1976,  0.1318],\n",
            "        [-0.2422,  0.1442, -0.3080,  0.0412, -0.1960, -0.1528,  0.2207,  0.0611,\n",
            "          0.1328, -0.0794],\n",
            "        [-0.1568, -0.1540, -0.2523,  0.3159,  0.3109,  0.2662,  0.0114, -0.3075,\n",
            "          0.2552, -0.1213],\n",
            "        [-0.2692, -0.0040,  0.2898,  0.0185,  0.2445,  0.1733,  0.2510,  0.1060,\n",
            "          0.2154,  0.1866]], requires_grad=True))\n",
            "('mlp.2.bias', Parameter containing:\n",
            "tensor([-0.0084,  0.2636, -0.0709, -0.0688, -0.2840,  0.0220, -0.2668, -0.0687,\n",
            "         0.2403,  0.0824], requires_grad=True))\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ClPniA20GJ_H",
        "colab_type": "code",
        "outputId": "d83223f0-6d0c-4f68-8b78-551a981bc2a6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 161
        }
      },
      "source": [
        "\n",
        "from torch_geometric.nn import SAGEConv\n",
        "# Re-seed so this experiment is reproducible independently of the cells above.\n",
        "SEED = 110\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "torch.cuda.manual_seed(SEED)\n",
        "\n",
        "# Four nodes with five random features each.\n",
        "data = torch.randn((4, 5))\n",
        "in_channels = data.size(1)\n",
        "\n",
        "# The layer is built only after `data` has been sampled, keeping the RNG consumption order fixed.\n",
        "conv1 = SAGEConv(in_channels, in_channels * 2, normalize=True, bias=True)\n",
        "\n",
        "# Four edges given column-wise as index pairs: (0, 1), (1, 2), (2, 0), (3, 1).\n",
        "edge_index = torch.tensor([[0, 1, 2, 3],\n",
        "                           [1, 2, 0, 1]])\n",
        "out1 = conv1(data, edge_index)\n",
        "print(out1)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([[-0.1716,  0.0126, -0.1496,  0.3717, -0.3207,  0.2212, -0.0057, -0.6195,\n",
            "          0.3733, -0.3672],\n",
            "        [ 0.1158,  0.0813, -0.0526,  0.1341, -0.2646,  0.0453,  0.1727, -0.5094,\n",
            "          0.3050, -0.7105],\n",
            "        [-0.4656, -0.2441, -0.3872,  0.1599, -0.2008,  0.0641,  0.2571, -0.4470,\n",
            "          0.0925, -0.4789],\n",
            "        [-0.0220,  0.5372,  0.1183,  0.2815, -0.3936,  0.4917,  0.1792, -0.3118,\n",
            "          0.1870, -0.2382]], grad_fn=<DivBackward0>)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RZj_JkyrHv8N",
        "colab_type": "code",
        "outputId": "39baa303-c4e8-46f2-dafe-ab2f3af6f185",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 269
        }
      },
      "source": [
        "# Print each learnable tensor of the SAGEConv layer as a (name, parameter) tuple.\n",
        "for name, param in conv1.named_parameters():\n",
        "    print((name, param))"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "('weight', Parameter containing:\n",
            "tensor([[ 0.3308, -0.1787, -0.1229,  0.0042, -0.0797, -0.3138, -0.2138, -0.3413,\n",
            "          0.3229, -0.3584],\n",
            "        [ 0.3074,  0.0865,  0.2517, -0.2572, -0.0793, -0.3254,  0.3739,  0.3093,\n",
            "          0.1302, -0.3825],\n",
            "        [ 0.2008, -0.4437,  0.3567, -0.0463,  0.3451, -0.4021, -0.1792,  0.2553,\n",
            "          0.1946,  0.2043],\n",
            "        [ 0.3425,  0.3147,  0.4000,  0.0580,  0.0461,  0.1469, -0.1831,  0.1515,\n",
            "          0.0723,  0.2788],\n",
            "        [-0.3966,  0.0401,  0.2044,  0.2469, -0.2935,  0.0677,  0.2220,  0.3087,\n",
            "          0.3266,  0.4087]], requires_grad=True))\n",
            "('bias', Parameter containing:\n",
            "tensor([-0.3693, -0.1523, -0.0352,  0.2800, -0.1599,  0.1650,  0.2416, -0.3623,\n",
            "         0.2339, -0.3501], requires_grad=True))\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}